from typing import Tuple, Callable
import re
import copy
from pathlib import Path
from tqdm import tqdm

import fire
from datasets import load_dataset
from nltk import edit_distance
from sklearn.metrics import accuracy_score
from sklearn.metrics import matthews_corrcoef

from src.util.io import write_jsonl
from src.util.chat import Conversation
from src.util.dl import gen_batch
from src.util.load import load_saiga
from src.util.openai import openai_batch_completion, OpenAIDecodingArguments

HF_DATASET = "RussianNLP/russian_super_glue"


def generate(
    model,
    tokenizer,
    prompts,
    generation_config,
    debug: bool = True
):
    data = tokenizer(
        prompts,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    data = {k: v.to(model.device) for k, v in data.items()}
    output_ids = model.generate(
        **data,
        generation_config=generation_config
    )
    outputs = []
    for sample_output_ids, sample_input_ids in zip(output_ids, data["input_ids"]):
        sample_output_ids = sample_output_ids[len(sample_input_ids):]
        sample_output = tokenizer.decode(sample_output_ids, skip_special_tokens=True)
        sample_output = sample_output.replace("</s>", "").strip()
        if debug:
            print(tokenizer.decode(sample_input_ids, skip_special_tokens=True))
            print(sample_output)
            print()
        outputs.append(sample_output)
    return outputs


def predict_saiga_zero_shot(
    model,
    tokenizer,
    generation_config,
    template_path,
    prompts,
    max_prompt_tokens: int = None,
    debug: bool = False
):
    default_conversation = Conversation.from_template(template_path)
    clean_prompts = []
    for prompt in prompts:
        conversation = copy.deepcopy(default_conversation)
        conversation.add_user_message(prompt)
        prompt = conversation.get_prompt(tokenizer, max_tokens=max_prompt_tokens)
        clean_prompts.append(prompt)
    return generate(
        model=model,
        tokenizer=tokenizer,
        prompts=clean_prompts,
        generation_config=generation_config,
        debug=debug
    )


def find_lcs(s1, s2):
    max_lcs = ""
    for i in range(len(s1)):
        for j in range(i + 1, len(s1)):
            ss1 = s1[i:j]
            if ss1 in s2 and len(ss1) > len(max_lcs):
                max_lcs = ss1
    return max_lcs

# DaNetQA


DANETQA_PROMPT = '''Контекст: {passage}

Используя контекст, ответь одним словом на вопрос: {question}'''

DANETQA_YES_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(да|верно|правда|может)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)
DANETQA_NO_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(нет|неверно|неправда|не|ложь|редко)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)


def clean_danetqa_response(response):
    result = True
    if bool(DANETQA_YES_RE.match(response)):
        result = True
    elif bool(DANETQA_NO_RE.match(response)):
        result = False
    else:
        print("ERROR! Не удалось найти Да/Нет в ответе модели и преобразовать его в bool:", response)
    return result


def predict_danetqa(
    split,
    predict_func,
    output_path,
    batch_size: int = 4,
    nrows: int = None,
    template: str = DANETQA_PROMPT,
    clean_func: Callable = clean_danetqa_response
):
    records = list(load_dataset(HF_DATASET, "danetqa", split=split))
    if nrows:
        records = records[:nrows]

    prompts = []
    for record in records:
        prompt = template.format(passage=record["passage"], question=record["question"])
        prompts.append(prompt)

    responses = []
    for batch in tqdm(gen_batch(prompts, batch_size), total=len(prompts) // batch_size + 1):
        responses.extend(predict_func(batch))

    labels, predictions = [], []
    for record, response in zip(records, responses):
        prediction = clean_func(response)
        record["prediction"] = prediction
        label = record["label"]
        if label != -1:
            labels.append(label)
            predictions.append(prediction)

    if labels:
        print("danetqa accuracy:", accuracy_score(labels, predictions))

    outputs = []
    for record in records:
        label = str(record["prediction"]).lower()
        outputs.append({"idx": record["idx"], "label": label})
    write_jsonl(outputs, output_path)

    return records

# TERRA


TERRA_PROMPT = '''Текст: {premise} Утверждение: {hypothesis}
Используя текст, ответь одним словом на вопрос: Вероятно ли утверждение при условии остального текста?'''

TERRA_ENTAILMENT_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(да|верно|правда|может|являются|вероятно)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)
TERRA_NOT_ENTAILMENT_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(нет|неверно|неверное|невероятно|не вероятно|не)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)


def terra_to_bool(response):
    return response == "entailment"


def clean_terra_response(response):
    result = "not_entailment"
    if bool(TERRA_ENTAILMENT_RE.match(response)):
        result = "entailment"
    elif bool(TERRA_NOT_ENTAILMENT_RE.match(response)):
        result = "not_entailment"
    else:
        print("ERROR! Не удалось найти Да/Нет в ответе модели и преобразовать его в bool", response)
    return result


def predict_terra(
    split,
    predict_func,
    output_path,
    batch_size: int = 8,
    nrows: int = None,
    template: str = TERRA_PROMPT,
    clean_func=clean_terra_response
):
    records = list(load_dataset(HF_DATASET, "terra", split=split))
    if nrows:
        records = records[:nrows]

    prompts = []
    for record in records:
        prompts.append(template.format(
            premise=record["premise"],
            hypothesis=record["hypothesis"]
        ))

    responses = []
    for batch in tqdm(gen_batch(prompts, batch_size), total=len(prompts) // batch_size + 1):
        responses.extend(predict_func(batch))

    labels, predictions = [], []
    for record, response in zip(records, responses):
        prediction = clean_func(response)
        record["prediction"] = prediction
        label = record["label"]
        if label != -1:
            labels.append(1 - label)
            predictions.append(terra_to_bool(prediction))

    if labels:
        print("terra accuracy:", accuracy_score(labels, predictions))

    outputs = [{"idx": r["idx"], "label": r["prediction"]} for r in records]
    write_jsonl(outputs, output_path)

    return records

# RWSD


RWSD_PROMPT = 'Текст: "{text}"\nНа основе текста одним словом ответь на вопрос: К кому или к чему относится местоимение во фразе "{span2}"?'


def clean_rwsd_response(response, span1):
    lcs = find_lcs(span1.lower(), response.lower())
    return len(lcs) >= 3


def predict_rwsd(
    split,
    predict_func,
    output_path,
    batch_size: int = 4,
    nrows: int = None,
    template: str = RWSD_PROMPT,
    clean_func: Callable = clean_rwsd_response
):
    records = list(load_dataset(HF_DATASET, "rwsd", split=split))
    if nrows:
        records = records[:nrows]

    prompts = []
    for record in records:
        prompts.append(template.format(
            text=record["text"],
            span2=record["span2_text"],
            span1=record["span1_text"],
        ))

    responses = []
    for batch in tqdm(gen_batch(prompts, batch_size), total=len(prompts) // batch_size + 1):
        responses.extend(predict_func(batch))

    labels, predictions = [], []
    for record, response in zip(records, responses):
        prediction = clean_func(response, record["span1_text"])
        record["prediction"] = prediction
        label = record["label"]
        if label != -1:
            labels.append(label)
            predictions.append(prediction)

    if labels:
        print("rwsd accuracy:", accuracy_score(labels, predictions))

    outputs = [{"idx": r["idx"], "label": str(r["prediction"])} for r in records]
    write_jsonl(outputs, output_path)

    return records

# MUSERC


MUSERC_SINGLE_PROMPT = """Текст: {text}

Вопрос: {question}

Является ли "{answer}" правильным ответом на этот вопрос? Основываясь на тексте, ответь только "да" или "нет"."""


MUSERC_SINGLE_YES_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(да|является)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)
MUSERC_SINGLE_NO_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(нет|не)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)


def clean_muserc_single_response(response):
    result = False
    if bool(MUSERC_SINGLE_YES_RE.match(response)):
        result = True
    elif bool(MUSERC_SINGLE_NO_RE.match(response)):
        result = False
    else:
        print("ERROR! Не удалось найти Да/Нет в ответе модели и преобразовать его в bool:", response)
    return result


def predict_muserc(
    split,
    predict_func,
    output_path,
    batch_size: int = 2,
    nrows: int = None,
    template: str = MUSERC_SINGLE_PROMPT,
    clean_func: Callable = clean_muserc_single_response
):
    records = list(load_dataset(HF_DATASET, "muserc", split=split))
    if nrows:
        records = records[:nrows]

    prompts = list()
    for record in records:
        text, question, answer = record["paragraph"], record["question"], record["answer"]
        answer = answer.rstrip(".")
        prompts.append(template.format(
            text=text,
            question=question,
            answer=answer
        ))

    responses = []
    for batch in tqdm(gen_batch(prompts, batch_size), total=len(prompts) // batch_size + 1):
        responses.extend(predict_func(batch))

    labels, predictions = [], []
    for record, response in zip(records, responses):
        record["prediction"] = clean_func(response)
        if record["label"] != -1:
            labels.append(record["label"])
            predictions.append(record["prediction"])

    if labels:
        print("muserc accuracy:", accuracy_score(labels, predictions))

    outputs = []
    prev_idx = None
    for record in records:
        idx = record["idx"]
        pidx, qidx, aidx = idx["paragraph"], idx["question"], idx["answer"]
        ppidx, pqidx = None, None
        if prev_idx:
            ppidx, pqidx = prev_idx["paragraph"], prev_idx["question"]

        if ppidx != pidx:
            outputs.append({"idx": pidx, "passage": {"questions": []}})
            assert len(outputs) - 1 == pidx
        paragraph = outputs[-1]

        if pqidx != qidx:
            paragraph["passage"]["questions"].append({"idx": qidx, "answers": []})

        question = paragraph["passage"]["questions"][-1]

        answer = {"idx": aidx, "label": int(record["prediction"])}
        question["answers"].append(answer)
        prev_idx = idx

    write_jsonl(outputs, output_path)
    return records

# RUCOS


def rucos_clean_text(text):
    text = " ".join([s.strip().rstrip(".") + "." for s in text.split("@header")]).strip()
    text = " ".join([s.strip().rstrip(".") + "." for s in text.split("@context")]).strip()
    text = " ".join([s.strip().rstrip(".") + "." for s in text.split("@highlight")]).strip()
    text = " ".join([s.strip() for s in text.split("\n") if s.strip()])
    return text


RUCOS_MASK = "[entity]"

RUCOS_PROMPT = """Контекст: {text}
Запрос: {query}

Какое имя человека или название организации или название места должно быть вместо {mask} в запросе? Ответь не более чем 3 словами в соответствии с контекстом."""


def clean_rucos_response(response, entities):
    answers = []
    for answer in entities:
        lcs = find_lcs(response.strip(), answer.strip())
        answers.append((len(lcs), answer))
    return max(answers)[1]


def predict_rucos(
    split,
    predict_func,
    output_path,
    batch_size: int = 4,
    nrows: int = None,
    debug: bool = False,
    template: str = RUCOS_PROMPT,
    clean_func: Callable = clean_rucos_response
):
    records = list(load_dataset(HF_DATASET, "rucos", split=split))
    if nrows:
        records = records[:nrows]

    prompts = list()
    for record in records:
        entities = record["entities"]
        query = record["query"]
        text = rucos_clean_text(record["passage"])
        entities = [e.strip().strip(",") for e in entities]
        query = query.replace("@placeholder", RUCOS_MASK)
        prompts.append(template.format(
            text=text,
            query=query,
            mask=RUCOS_MASK
        ))

    responses = []
    for batch in tqdm(gen_batch(prompts, batch_size), total=len(prompts) // batch_size + 1):
        responses.extend(predict_func(batch))

    correct_count, all_count = 0, 0
    for response, record in zip(responses, records):
        final_response = clean_func(response, record["entities"])
        record["prediction"] = final_response
        answers = record["answers"]
        if answers:
            all_count += 1
            prediction = record["prediction"].strip().lower()
            for answer in answers:
                answer = answer.strip().lower()
                if edit_distance(answer, prediction) <= 2:
                    correct_count += 1
                    break
    if all_count > 0:
        print("rucos accuracy:", correct_count / all_count)

    outputs = [{"idx": r["idx"]["query"], "label": r["prediction"]} for r in records]
    write_jsonl(outputs, output_path)

    return records


# LIDIRUS


LIDIRUS_PROMPT = '''Текст: "{sentence1}"

Используя текст, можно ли сказать, что утверждение "{sentence2}" точно корректно относительно ситуации из текста? Ответь только "да" или "нет".'''

LIDIRUS_ENTAILMENT_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(да|верно|правда|может|вероятна|верная)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)
LIDIRUS_NOT_ENTAILMENT_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(не|нет|неверно|неверное|невероятна|неверная)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)


def lidirus_to_bool(response):
    return response == "entailment"


def clean_lidirus_response(response):
    result = "not_entailment"
    if bool(LIDIRUS_ENTAILMENT_RE.match(response)):
        result = "entailment"
    elif bool(LIDIRUS_NOT_ENTAILMENT_RE.match(response)):
        result = "not_entailment"
    else:
        print("ERROR! Не удалось найти Да/Нет в ответе модели и преобразовать его в bool", response)
    return result


def predict_lidirus(
    predict_func,
    output_path,
    batch_size: int = 4,
    nrows: int = None,
    template: str = LIDIRUS_PROMPT,
    clean_func: Callable = clean_lidirus_response
):
    records = list(load_dataset(HF_DATASET, "lidirus", split="test"))
    if nrows:
        records = records[:nrows]

    prompts = [template.format(
        sentence1=r["sentence1"],
        sentence2=r["sentence2"]
    ) for r in records]

    responses = []
    for batch in tqdm(gen_batch(prompts, batch_size), total=len(prompts) // batch_size + 1):
        responses.extend(predict_func(batch))

    labels, predictions = [], []
    for record, response in zip(records, responses):
        prediction = clean_func(response)
        record["prediction"] = prediction
        label = record["label"]
        labels.append(1 - label)
        predictions.append(lidirus_to_bool(prediction))

    print("lidirus accuracy:", accuracy_score(labels, predictions))
    print("lidirus corr:", matthews_corrcoef(labels, predictions))

    outputs = [{"idx": r["idx"], "label": r["prediction"]} for r in records]
    write_jsonl(outputs, output_path)

    return records

# PARUS


PARUS_CAUSE_PROMPT = """Выбери одну наиболее вероятную причину исключительно из двух предложенных вариантов.

Варианты: {choice1}; {choice2}

{premise}, так как..."""

PARUS_EFFECT_PROMPT = """Выбери одно наиболее вероятное следствие исключительно из двух предложенных вариантов.

Варианты: {choice1}; {choice2}

{premise}, поэтому..."""


def predict_parus(
    split,
    predict_func,
    output_path,
    batch_size: int = 12,
    nrows: int = None,
    template_cause: str = PARUS_CAUSE_PROMPT,
    template_effect: str = PARUS_EFFECT_PROMPT
):
    records = list(load_dataset(HF_DATASET, "parus", split=split))
    if nrows:
        records = records[:nrows]

    prompts = []
    for r in records:
        idx = r["idx"]
        c1 = r["choice1"].rstrip(".").lower()
        c2 = r["choice2"].rstrip(".").lower()
        premise = r["premise"].rstrip(".")

        is_cause = r["question"] == "cause"
        template = template_cause if is_cause else template_effect
        prompts.append(template.format(
            premise=premise,
            choice1=c1,
            choice2=c2
        ))

    responses = list()
    for batch in tqdm(gen_batch(prompts, batch_size), total=len(prompts) // batch_size + 1):
        responses.extend(predict_func(batch))

    assert len(responses) == len(records)
    for idx, (response, record) in enumerate(zip(responses, records)):
        response = response.lower()
        c1 = record["choice1"].rstrip(".").lower()
        c2 = record["choice2"].rstrip(".").lower()
        c1_lcs = find_lcs(response, c1)
        c2_lcs = find_lcs(response, c2)
        record["prediction"] = int(len(c2_lcs) > len(c1_lcs))

    if records[0]["label"] != -1:
        y_true, y_pred = [], []
        for r in records:
            y_pred.append(r["prediction"])
            y_true.append(r["label"])
        score = accuracy_score(y_true, y_pred)
        print("parus accuracy:", score)

    outputs = [{"idx": r["idx"], "label": int(r["prediction"])} for r in records]
    write_jsonl(outputs, output_path)

    return records

# RCB


RCB_PROMPT = """Дан текст: "{premise}"

Ответь на вопрос по тексту "да", "нет" или "может быть": {question}"""

RCB_YES_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(да|верно|вероятно)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)

RCB_NO_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(нет|неверно|неверное|невероятно|не)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)


def clean_rcb_response(response):
    is_contradiction = bool(RCB_NO_RE.match(response))
    is_entailment = bool(RCB_YES_RE.match(response))
    if is_contradiction:
        return "contradiction"
    if is_entailment:
        return "entailment"
    return "neutral"


def rcb_label2index(label):
    mapping = {
        "entailment": 0,
        "contradiction": 1,
        "neutral": 2
    }
    return mapping[label]


def predict_rcb(
    split,
    predict_func,
    output_path,
    batch_size: int = 8,
    nrows: int = None,
    template: str = RCB_PROMPT,
    clean_func: Callable = clean_rcb_response
):
    records = list(load_dataset(HF_DATASET, "rcb", split=split))
    if nrows:
        records = records[:nrows]

    questions = [record["hypothesis"].rstrip(".") + "?" for record in records]

    prompts = []
    for record, question in zip(records, questions):
        prompts.append(template.format(
            premise=record["premise"],
            question=question
        ))

    responses = []
    for batch in tqdm(gen_batch(prompts, batch_size), total=len(prompts) // batch_size + 1):
        responses.extend(predict_func(batch))

    for r, response in zip(records, responses):
        r["prediction"] = clean_func(response)

    if records[0]["label"] != -1:
        labels = [r["label"] for r in records]
        responses = [rcb_label2index(r["prediction"]) for r in records]
        print("rcb accuracy:", accuracy_score(labels, responses))

    outputs = [{"idx": r["idx"], "label": r["prediction"]} for r in records]
    write_jsonl(outputs, output_path)

    return records


# RUSSE


RUSSE_PROMPT = '''Ответь только "да" или "нет" на вопрос:
В текстовом фрагменте "{sentence1}" и текстовом фрагменте "{sentence2}" означают ли слова "{word}" разное?'''


RUSSE_YES_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(да|верно|вероятно|одно)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)

RUSSE_NO_RE = re.compile(
    r"^[^\w]*(Выходные данные|Выход|Ответ|Оценка)?[^\w]*(нет|не)",
    re.IGNORECASE | re.MULTILINE | re.DOTALL
)


def clean_russe_response(response):
    if bool(RUSSE_YES_RE.match(response)):
        return 0
    if bool(RUSSE_NO_RE.match(response)):
        return 1
    print("ERROR! Не удалось найти Да/Нет в ответе модели и преобразовать его в bool:", response)
    return 1


def predict_russe(
    split,
    predict_func,
    output_path,
    batch_size: int = 8,
    nrows: int = None,
    template: str = RUSSE_PROMPT,
    clean_func: Callable = clean_russe_response
):
    records = list(load_dataset(HF_DATASET, "russe", split=split))
    if nrows:
        records = records[:nrows]

    prompts = []
    for record in records:
        prompts.append(template.format(
            sentence1=record["sentence1"],
            sentence2=record["sentence2"],
            word=record["word"]
        ))

    responses = []
    for batch in tqdm(gen_batch(prompts, batch_size), total=len(prompts) // batch_size + 1):
        responses.extend(predict_func(batch))

    for r, response in zip(records, responses):
        r["prediction"] = clean_func(response)

    if records[0]["label"] != -1:
        labels = [r["label"] for r in records]
        responses = [r["prediction"] for r in records]
        print("russe accuracy:", accuracy_score(labels, responses))

    outputs = [{
        "idx": r["idx"],
        "label": str(bool(r["prediction"])).lower()
    } for r in records]
    write_jsonl(outputs, output_path)

    return records


ALL_TASKS = ("danetqa", "lidirus", "muserc", "parus", "rcb", "rucos", "russe", "rwsd", "terra")


def main(
    model_name,
    nrows: int = None,
    template_path: str = "internal_prompts/saiga_v2.json",
    split: str = "test",
    predictions_dir: str = "submission",
    debug: bool = False,
    tasks: Tuple[str] = ALL_TASKS
):
    predictions_dir = Path(predictions_dir)

    predict_short = None
    predict_long = None

    if model_name not in ("gpt-4", "gpt-3.5-turbo"):
        model, tokenizer, generation_config = load_saiga(model_name)
        generation_config.no_repeat_ngram_size = 64
        generation_config.temperature = 0.01

        def predict_saiga_zero_shot_bound(batch):
            generation_config.max_new_tokens = 256
            return predict_saiga_zero_shot(
                model=model,
                tokenizer=tokenizer,
                generation_config=generation_config,
                template_path=template_path,
                prompts=batch,
                debug=debug
            )

        def predict_saiga_zero_shot_bound_short(batch):
            generation_config.max_new_tokens = 8
            return predict_saiga_zero_shot(
                model=model,
                tokenizer=tokenizer,
                generation_config=generation_config,
                template_path=template_path,
                prompts=batch,
                debug=debug
            )

        predict_long = predict_saiga_zero_shot_bound
        predict_short = predict_saiga_zero_shot_bound_short

    else:
        def predict_chatgpt(batch):
            messages = [[{"role": "user", "content": prompt}] for prompt in batch]
            responses = openai_batch_completion(messages, model_name=model_name)
            responses = [r.message.content for r in responses]
            if debug:
                for prompt, response in zip(batch, responses):
                    print(prompt)
                    print(response)
                    print()
            return responses

        def predict_chatgpt_short(batch):
            messages = [[{"role": "user", "content": prompt}] for prompt in batch]
            responses = openai_batch_completion(
                messages,
                decoding_args=OpenAIDecodingArguments(max_tokens=16),
                model_name=model_name
            )
            responses = [r.message.content for r in responses]
            if debug:
                for prompt, response in zip(batch, responses):
                    print(prompt)
                    print(response)
                    print()
            return responses

        predict_long = predict_chatgpt
        predict_short = predict_chatgpt_short

    if "danetqa" in tasks:
        predict_danetqa(
            split=split,
            predict_func=predict_short,
            output_path=predictions_dir / "DaNetQA.jsonl",
            nrows=nrows
        )

    if "terra" in tasks:
        predict_terra(
            split=split,
            predict_func=predict_short,
            output_path=predictions_dir / "TERRa.jsonl",
            nrows=nrows
        )

    if "rwsd" in tasks:
        predict_rwsd(
            split=split,
            predict_func=predict_long,
            output_path=predictions_dir / "RWSD.jsonl",
            nrows=nrows
        )

    if "rucos" in tasks:
        predict_rucos(
            split=split,
            predict_func=predict_long,
            output_path=predictions_dir / "RuCoS.jsonl",
            nrows=nrows
        )

    if "lidirus" in tasks:
        predict_lidirus(
            predict_func=predict_short,
            output_path=predictions_dir / "LiDiRus.jsonl",
            nrows=nrows
        )
    if "parus" in tasks:
        predict_parus(
            split=split,
            predict_func=predict_long,
            output_path=predictions_dir / "PARus.jsonl",
            nrows=nrows
        )
    if "rcb" in tasks:
        predict_rcb(
            split=split,
            predict_func=predict_long,
            output_path=predictions_dir / "RCB.jsonl",
            nrows=nrows
        )
    if "russe" in tasks:
        predict_russe(
            split=split,
            predict_func=predict_short,
            output_path=predictions_dir / "RUSSE.jsonl",
            nrows=nrows
        )
    if "muserc" in tasks:
        predict_muserc(
            split=split,
            predict_func=predict_short,
            output_path=predictions_dir / "MuSeRC.jsonl",
            nrows=nrows
        )


if __name__ == "__main__":
    fire.Fire(main)
